/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/util/fallback_test_util.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/utils/thread_pool.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/tracing/tracing.h"  // from @tf_runtime

namespace tensorflow {
namespace {

// Loads the pre-compiled BEF test file `file_name` from the data directory
// tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/ and opens it with the
// kernel registry, diagnostic handler and allocator of `host`.
//
// Returns the buffer holding the raw BEF bytes together with the opened
// BEFFile. The BEFFile is opened over that buffer (presumably without copying
// it — confirm against tfrt::BEFFile::Open), so callers should keep the whole
// returned pair alive for as long as the BEFFile is used.
//
// Aborts via CHECK if the file cannot be read or cannot be opened as BEF.
//
// TODO(b/175648326): Move the function below to the common test utilities for
// BEF.
std::pair<tfrt::BefBuffer, tfrt::RCReference<tfrt::BEFFile>> CreateBefFile(
    absl::string_view file_name, tfrt::HostContext* host) {
  std::string file_path = GetDataDependencyFilepath(absl::StrCat(
      "tensorflow/compiler/mlir/tfrt/tests/tfrt_fallback/", file_name));
  std::string data;
  CHECK_OK(ReadFileToString(Env::Default(), file_path, &data));

  tfrt::BefBuffer bef_buffer(data.begin(), data.end());

  auto bef_file = tfrt::BEFFile::Open(bef_buffer, host->GetKernelRegistry(),
                                      host->diag_handler(), host->allocator());
  // Name the offending file so a failing test points at the right data dep.
  CHECK(bef_file) << "Failed to open BEF file: " << file_path;
  // Moving the buffer hands over its storage; the bytes the BEFFile was opened
  // over are not relocated (BefBuffer is vector-like — confirm if it changes).
  return {std::move(bef_buffer), std::move(bef_file)};
}

TEST(KernelFallbackCompatTest, CreateOp) {
  // Enable debug-level tracing; the level is set through a free function and
  // is not reset when this test finishes.
  tfrt::tracing::SetTracingLevel(tfrt::tracing::TracingLevel::Debug);
  auto runtime =
      tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4);
  auto* host = runtime->core_runtime()->GetHostContext();

  // Keep the whole pair alive: it owns the buffer the BEF file was opened on.
  auto buffer_and_file = CreateBefFile("create_op.mlir.bef", host);
  auto& bef_file = buffer_and_file.second;

  tfrt::ResourceContext resource_ctx;
  auto exec_ctx = tfd::CreateFallbackTestExecutionContext(host, &resource_ctx);

  auto ready_chain = tfrt::GetReadyChain();

  auto* init = bef_file->GetFunction("init");
  ASSERT_NE(init, nullptr);

  // "init" is executed with a single result slot.
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> results(1);
  init->Execute(exec_ctx, {ready_chain.GetAsyncValue()}, results);
  host->Await(results);
  ASSERT_FALSE(results[0]->IsError()) << results[0]->GetError().message();

  // After "init", the request context should carry the fallback request state
  // with a rendezvous and populated kernel runner entries.
  auto* request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<tfd::KernelFallbackCompatRequestState>();
  ASSERT_NE(request_state, nullptr);

  auto* rendezvous = request_state->rendezvous();
  ASSERT_NE(rendezvous, nullptr);

  auto* runners = request_state->runner_table();

  // TODO(tfrt-devs): Create a special key type for OpKernelRunnerCache instead
  // of using tfrt::Location. The key should be generated by higher-level
  // applications such as compiler, and it is higher-level applications'
  // responsibility to make sure the key is unique in an OpKernelRunnerCache
  // instance.
  auto* add_runner = runners->Get(0);
  ASSERT_NE(add_runner, nullptr);

  auto* flat_map_runner = runners->Get(1);
  ASSERT_NE(flat_map_runner, nullptr);
}

// Runs the "init" and "run" functions of custom_thread_pool.mlir.bef with an
// execution context built on an explicitly supplied single-threaded
// TfThreadPool, and checks that every produced result is free of errors.
TEST(KernelFallbackCompatTest, CustomThreadPool) {
  auto runtime =
      tensorflow::tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4);
  auto* host = runtime->core_runtime()->GetHostContext();

  // `pair` owns the BEF buffer as well as the file; keep it alive throughout.
  auto pair = CreateBefFile("custom_thread_pool.mlir.bef", host);
  auto& bef_file = pair.second;

  tfrt::ResourceContext resource_ctx;

  // Single-threaded pool handed explicitly to the execution context below.
  tensorflow::tfrt_stub::TfThreadPool thread_pool(/*name=*/"test",
                                                  /*num_threads=*/1);

  auto exec_ctx = tfd::CreateFallbackTestExecutionContext(host, &resource_ctx,
                                                          &thread_pool);

  auto chain = tfrt::GetReadyChain();

  auto* init_func = bef_file->GetFunction("init");
  ASSERT_NE(init_func, nullptr);

  // "init" is executed with one result slot.
  std::vector<tfrt::RCReference<tfrt::AsyncValue>> results;
  results.resize(1);

  init_func->Execute(exec_ctx, {chain.GetAsyncValue()}, results);
  host->Await(results);
  ASSERT_FALSE(results[0]->IsError()) << results[0]->GetError().message();

  auto* run_func = bef_file->GetFunction("run");
  ASSERT_NE(run_func, nullptr);

  // "run" is executed with two result slots; verify every one of them so an
  // error surfacing only in a later result is not silently ignored.
  results.clear();
  results.resize(2);
  run_func->Execute(exec_ctx, {chain.GetAsyncValue()}, results);
  host->Await(results);
  for (const auto& result : results) {
    ASSERT_FALSE(result->IsError()) << result->GetError().message();
  }
}

}  // namespace
}  // namespace tensorflow
